{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Quantization-Aware Training with BatchNorm Re-estimation\n",
    "\n",
    "This notebook shows a working code example of how to use AIMET to perform QAT (Quantization-aware training) with batchnorm re-estimation.\n",
    "Batchnorm re-estimation is a technique for countering potential instability of batchnorm statistics (i.e. running\n",
    "mean and variance) during QAT. More specifically, batchnorm re-estimation recalculates the batchnorm statistics based on the model after QAT. By doing so, we aim to make our model learn batchnorm statistics from stable outputs after QAT, rather than from likely noisy outputs during QAT.\n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following steps:\n",
    "1. Instantiate the example evaluation and training pipeline\n",
    "2. Define Constants and Datasets Prepare\n",
    "3. Create the model in Keras\n",
    "4. Train and evaluate the model\n",
    "5. Quantize the model with QuantSim\n",
    "6. Finetune and evaluate the quantization simulation model\n",
    "7. Re-estimate batchnorm statistics and compare the eval score before and after re-estimation\n",
    "8. Fold the re-estimated batchnorm layers and export the quantization simulation model\n"
   ]
  },
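  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a framework-free illustration of the idea described above, the sketch below averages the per-batch mean and variance of a single channel over a handful of batches. The function name `reestimate_stats` and the toy data are assumptions made for illustration only; this is not AIMET's implementation, which works on the batchnorm layers of the model itself.\n",
    "\n",
    "```python\n",
    "from statistics import fmean, pvariance\n",
    "\n",
    "\n",
    "def reestimate_stats(batches):\n",
    "    # Compute the mean and population variance of every batch,\n",
    "    # then average those per-batch statistics over all batches.\n",
    "    batch_means = [fmean(batch) for batch in batches]\n",
    "    batch_vars = [pvariance(batch) for batch in batches]\n",
    "    return fmean(batch_means), fmean(batch_vars)\n",
    "\n",
    "\n",
    "# Toy activations of one channel: two batches of four values each\n",
    "batches = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]\n",
    "mean, var = reestimate_stats(batches)\n",
    "```\n",
    "\n",
    "Because every batch is weighted equally, the result does not depend on the order in which the batches were seen, unlike a running average that is updated with momentum during training."
   ]
  },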
  {
   "cell_type": "markdown",
   "source": [
    "---\n",
    "## Dataset\n",
    "\n",
    "This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Else, please download the dataset from appropriate location (e.g. [https://image-net.org/challenges/LSVRC/2012/index.php#](https://image-net.org/challenges/LSVRC/2012/index.php#))\n",
    "\n",
    "**Note**: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. E.g. the entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. But for the purpose of running this notebook, you could perhaps reduce the dataset to say 2 samples per class. This exercise is left upto the reader and is not necessary.\n",
    "\n",
    "Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/dir/'       # Please replace this with a real directory"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
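  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to build the reduced subset mentioned in the note above, one possible approach is sketched here. It assumes a layout with one subdirectory per class; the function name `make_subset` and the paths in the usage line are hypothetical and are not part of AIMET or of this notebook's pipeline.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def make_subset(src_dir, dst_dir, samples_per_class=2):\n",
    "    # Copy the first few files of every class subdirectory into dst_dir,\n",
    "    # keeping the class subdirectory names unchanged.\n",
    "    copied = 0\n",
    "    for class_dir in sorted(Path(src_dir).iterdir()):\n",
    "        if not class_dir.is_dir():\n",
    "            continue\n",
    "        target = Path(dst_dir) / class_dir.name\n",
    "        target.mkdir(parents=True, exist_ok=True)\n",
    "        files = sorted(p for p in class_dir.iterdir() if p.is_file())\n",
    "        for image_path in files[:samples_per_class]:\n",
    "            shutil.copy2(image_path, target / image_path.name)\n",
    "            copied += 1\n",
    "    return copied\n",
    "```\n",
    "\n",
    "A call such as `make_subset('/path/to/full/val', '/path/to/subset/val')` would then copy two images per class, and the subset directory could be used as the dataset location."
   ]
  },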
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import tensorflow as tf"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 1. Instantiate the example evaluation and training pipeline\n",
    "\n",
    "The following is an example training and validation loop for this image classification task.\n",
    "\n",
    "- **Does AIMET have any limitations on how the training, validation pipeline is written?** Not really. We will see later that AIMET will modify the user's model to create a QuantizationSim model which is still a TensorFlow model. This QuantizationSim model can be used in place of the original model when doing inference or training.\n",
    "- **Does AIMET put any limitation on the interface of evaluate() or train() methods?** Not really. You should be able to use your existing evaluate and train routines as-is."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "from typing import Optional\n",
    "from Examples.common import image_net_config\n",
    "from Examples.tensorflow.utils.keras.image_net_dataset import ImageNetDataset\n",
    "from Examples.tensorflow.utils.keras.image_net_evaluator import ImageNetEvaluator\n",
    "\n",
    "\n",
    "class ImageNetDataPipeline:\n",
    "    \"\"\"\n",
    "    Provides APIs for model evaluation and finetuning using ImageNet Dataset.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def get_val_dataset(batch_size: Optional[int] = None) -> tf.data.Dataset:\n",
    "        \"\"\"\n",
    "        Instantiates a validation dataloader for ImageNet dataset and returns it\n",
    "        :return: A tensorflow dataset\n",
    "        \"\"\"\n",
    "        if batch_size is None:\n",
    "            batch_size = image_net_config.evaluation['batch_size']\n",
    "\n",
    "        data_loader = ImageNetDataset(DATASET_DIR,\n",
    "                                      image_size=image_net_config.dataset['image_size'],\n",
    "                                      batch_size=batch_size)\n",
    "\n",
    "        return data_loader\n",
    "\n",
    "    @staticmethod\n",
    "    def evaluate(model, iterations=None) -> float:\n",
    "        \"\"\"\n",
    "        Given a Keras model, evaluates its Top-1 accuracy on the validation dataset\n",
    "        :param model: The Keras model to be evaluated.\n",
    "        :param iterations: The number of iterations to run. If None, all the data will be used\n",
    "        :return: The accuracy for the sample with the maximum accuracy.\n",
    "        \"\"\"\n",
    "        evaluator = ImageNetEvaluator(DATASET_DIR,\n",
    "                                      image_size=image_net_config.dataset[\"image_size\"],\n",
    "                                      batch_size=image_net_config.evaluation[\"batch_size\"])\n",
    "\n",
    "        return evaluator.evaluate(model=model, iterations=iterations)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Define Constants and Datasets Prepare\n",
    "\n",
    "In this section the constants and helper functions needed to run this example are defined.\n",
    "\n",
    "- **EVAL_DATASET_SIZE** To execute this example faster this value has been set to 4\n",
    "- **TRAIN_DATASET_SIZE** To execute this example faster this value has been set to 4\n",
    "- **RE_ESTIMATION_DATASET_SIZE** To execute this example faster this value has been set to 4\n",
    "- **BATCH_SIZE** User sets the batch size. As an example, set to 16\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "EVAL_DATASET_SIZE = 4\n",
    "TRAIN_DATASET_SIZE = 4\n",
    "RE_ESTIMATION_DATASET_SIZE = 4\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "dataset = ImageNetDataPipeline.get_val_dataset(BATCH_SIZE).dataset\n",
    "eval_dataset = dataset.take(EVAL_DATASET_SIZE)\n",
    "train_dataset = dataset.take(TRAIN_DATASET_SIZE)\n",
    "unlabeled_dataset = dataset.map(lambda images, labels: images)\n",
    "re_estimation_dataset = unlabeled_dataset.take(RE_ESTIMATION_DATASET_SIZE)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Create the model in Keras\n",
    "\n",
    "Currently, only Keras models built using the Sequential or Functional APIs are compatible with QuantSim - models making use of subclassed layers are incompatible. Therefore, we use the Functional API to create the model used in this example"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()\n",
    "inputs = tf.keras.Input(shape=(224, 224, 3), name=\"inputs\")\n",
    "conv = tf.keras.layers.Conv2D(16, (3, 3), name ='conv1')(inputs)\n",
    "bn = tf.keras.layers.BatchNormalization(fused=True)(conv)\n",
    "relu = tf.keras.layers.ReLU()(bn)\n",
    "pool = tf.keras.layers.MaxPooling2D()(relu)\n",
    "conv2 = tf.keras.layers.Conv2D(8, (3, 3), name ='conv2')(pool)\n",
    "flatten = tf.keras.layers.Flatten()(conv2)\n",
    "dense  = tf.keras.layers.Dense(1000)(flatten)\n",
    "functional_model = tf.keras.Model(inputs=inputs, outputs=dense)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 3. Train and evaluate the model\n",
    "\n",
    "Before we can quantize the model and apply QAT, the FP32 model must be trained so that we can get a baseline accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "loss_fn = tf.keras.losses.CategoricalCrossentropy()\n",
    "\n",
    "functional_model.compile(optimizer='adam',\n",
    "              loss=loss_fn,\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "functional_model.fit(train_dataset, epochs=5)\n",
    "\n",
    "# Evaluate the model on the test data using `evaluate`\n",
    "print(\"Evaluate quantized model (post QAT) on test data\")\n",
    "ImageNetDataPipeline.evaluate(model=functional_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 4. Create a QuantizationSim Model\n",
    "\n",
    "Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them.\n",
    "A few of the parameters are explained here\n",
    "- **quant_scheme**: We set this to \"QuantScheme.training_range_learning_with_tf_init\"\n",
    "    - Supported options are 'tf_enhanced' or 'tf' or using Quant Scheme Enum QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced\n",
    "- **default_output_bw**: Setting this to 8, essentially means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision\n",
    "- **default_param_bw**: Setting this to 8, essentially means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from aimet_common.defs import QuantScheme\n",
    "from aimet_tensorflow.keras.quantsim import QuantizationSimModel\n",
    "\n",
    "default_config_per_channel = {\n",
    "        \"defaults\":\n",
    "            {\n",
    "                \"ops\":\n",
    "                    {\n",
    "                        \"is_output_quantized\": \"True\"\n",
    "                    },\n",
    "                \"params\":\n",
    "                    {\n",
    "                        \"is_quantized\": \"True\",\n",
    "                        \"is_symmetric\": \"True\"\n",
    "                    },\n",
    "                \"strict_symmetric\": \"False\",\n",
    "                \"unsigned_symmetric\": \"True\",\n",
    "                \"per_channel_quantization\": \"True\"\n",
    "            },\n",
    "\n",
    "        \"params\":\n",
    "            {\n",
    "                \"bias\":\n",
    "                    {\n",
    "                        \"is_quantized\": \"False\"\n",
    "                    }\n",
    "            },\n",
    "\n",
    "        \"op_type\":\n",
    "            {\n",
    "                \"Squeeze\":\n",
    "                    {\n",
    "                        \"is_output_quantized\": \"False\"\n",
    "                    },\n",
    "                \"Pad\":\n",
    "                    {\n",
    "                        \"is_output_quantized\": \"False\"\n",
    "                    },\n",
    "                \"Mean\":\n",
    "                    {\n",
    "                        \"is_output_quantized\": \"False\"\n",
    "                    }\n",
    "            },\n",
    "\n",
    "        \"supergroups\":\n",
    "            [\n",
    "                {\n",
    "                    \"op_list\": [\"Conv\", \"Relu\"]\n",
    "                },\n",
    "                {\n",
    "                    \"op_list\": [\"Conv\", \"Clip\"]\n",
    "                },\n",
    "                {\n",
    "                    \"op_list\": [\"Conv\", \"BatchNormalization\", \"Relu\"]\n",
    "                },\n",
    "                {\n",
    "                    \"op_list\": [\"Add\", \"Relu\"]\n",
    "                },\n",
    "                {\n",
    "                    \"op_list\": [\"Gemm\", \"Relu\"]\n",
    "                }\n",
    "            ],\n",
    "\n",
    "        \"model_input\":\n",
    "            {\n",
    "                \"is_input_quantized\": \"True\"\n",
    "            },\n",
    "\n",
    "        \"model_output\":\n",
    "            {}\n",
    "    }\n",
    "\n",
    "with open(\"/tmp/default_config_per_channel.json\", \"w\") as f:\n",
    "    json.dump(default_config_per_channel, f)\n",
    "\n",
    "\n",
    "qsim = QuantizationSimModel(functional_model, quant_scheme=QuantScheme.training_range_learning_with_tf_init,\n",
    "                                config_file=\"/tmp/default_config_per_channel.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prepare the evaluation callback function\n",
    "\n",
    "The **eval_callback()** function takes the model object to evaluate and compile option dictionary and the number of samples to use as arguments. If the **num_samples** argument is None, the whole evaluation dataset is used to evaluate the model."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "\n",
    "def eval_callback(model: tf.keras.Model,\n",
    "                  num_samples: Optional[int] = None) -> float:\n",
    "    if num_samples is None:\n",
    "        num_samples = EVAL_DATASET_SIZE\n",
    "\n",
    "    sampled_dataset = eval_dataset.take(num_samples)\n",
    "\n",
    "    # Model should be compiled before evaluation\n",
    "    model.compile(optimizer=tf.keras.optimizers.Adam(),\n",
    "                  loss=tf.keras.losses.CategoricalCrossentropy(),\n",
    "                  metrics=tf.keras.metrics.CategoricalAccuracy())\n",
    "    _, acc = model.evaluate(sampled_dataset)\n",
    "\n",
    "    return acc"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "**Compute Encodings**\n",
    "\n",
    "Even though AIMET has added 'quantizer' nodes to the model graph but the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics which will then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as 'computing encodings'.\n",
    "\n",
    "So we create a routine to pass unlabeled data samples through the model. This should be fairly simple - use the existing train or validation data loader to extract some samples and pass them to the model. We don't need to compute any loss metric etc. So we can just ignore the model output for this purpose. A few pointers regarding the data samples\n",
    "- In practice, we need a very small percentage of the overall data samples for computing encodings.\n",
    "- It may be beneficial if the samples used for computing encoding are well distributed. It's not necessary that all classes need to be covered etc. since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario like all positive or negative samples are used.\n",
    "\n",
    "The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways, this is just an example."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qsim.compute_encodings(eval_callback, forward_pass_callback_args=None)"
   ]
  },
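  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scale/offset terminology used above more concrete, the sketch below derives asymmetric 8-bit encodings from an observed activation range and applies a quantize-dequantize step. The function names `compute_encoding` and `fake_quantize`, and the rounding convention, are assumptions for illustration; AIMET's quantizers may differ in their details.\n",
    "\n",
    "```python\n",
    "def compute_encoding(min_val, max_val, bitwidth=8):\n",
    "    # Extend the range to include zero so that 0.0 is exactly representable\n",
    "    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    # Guard against a degenerate all-zero range\n",
    "    scale = max((max_val - min_val) / num_steps, 1e-8)\n",
    "    offset = round(min_val / scale)\n",
    "    return scale, offset\n",
    "\n",
    "\n",
    "def fake_quantize(x, scale, offset, bitwidth=8):\n",
    "    # Map to the integer grid, clamp to the valid range, then map back to float\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    q = round(x / scale) - offset\n",
    "    q = max(0, min(num_steps, q))\n",
    "    return (q + offset) * scale\n",
    "\n",
    "\n",
    "# Suppose calibration observed activations in the range [-1.0, 3.0]\n",
    "scale, offset = compute_encoding(-1.0, 3.0)\n",
    "```\n",
    "\n",
    "Values inside the calibrated range are reproduced to within half a quantization step, while values outside of it are clipped. This is why the samples used for computing encodings should cover a representative range of activations."
   ]
  },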
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Next, we can evaluate the performance of the quantized model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print(\"Evaluate quantized model on test data\")\n",
    "ImageNetDataPipeline.evaluate(model=qsim.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 5. Perform QAT\n",
    "\n",
    "To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.\n",
    "For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "quantized_callback = tf.keras.callbacks.TensorBoard(log_dir=\"./log/quantized\")\n",
    "history = qsim.model.fit(\n",
    "    train_dataset, batch_size=4, epochs=1, validation_data=eval_dataset,\n",
    "    callbacks=[quantized_callback]\n",
    ")"
   ]
  },
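  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The suggestion above of dropping the learning rate by a factor of 10 every 5 epochs can be sketched as a plain Python schedule. The name `make_step_decay` and the starting rate of 1e-5 are illustrative assumptions, not values recommended by AIMET.\n",
    "\n",
    "```python\n",
    "def make_step_decay(initial_lr, drop_factor=0.1, epochs_per_drop=5):\n",
    "    def schedule(epoch, lr=None):\n",
    "        # The current rate (lr) is ignored on purpose: the new rate\n",
    "        # depends only on the zero-based epoch index.\n",
    "        return initial_lr * drop_factor ** (epoch // epochs_per_drop)\n",
    "    return schedule\n",
    "\n",
    "\n",
    "schedule = make_step_decay(initial_lr=1e-5)\n",
    "# Learning rates for a hypothetical 15-epoch QAT run\n",
    "rates = [schedule(epoch) for epoch in range(15)]\n",
    "```\n",
    "\n",
    "For a longer run, wrapping the schedule as `tf.keras.callbacks.LearningRateScheduler(schedule)` and adding it to the `callbacks` list passed to `fit` would apply it during training."
   ]
  },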
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's evaluate the validation accuracy of our model after QAT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print(\"Evaluate quantized model (post QAT) on test data\")\n",
    "ImageNetDataPipeline.evaluate(model=qsim.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***6. Re-estimate BatchNorm Statistics***\n",
    "\n",
    "AIMET provides a helper function, `reestimate_bn_stats`, for re-estimating batchnorm statistics.\n",
    "Here is the full list of parameters for this function:\n",
    "* **model**: Model to re-estimate the BatchNorm statistics.\n",
    "* **dataloader** Train dataloader.\n",
    "* **num_batches** (optional): The number of batches to be used for reestimation. (Default: 100)\n",
    "* **forward_fn** (optional): Optional adapter function that performs forward pass given a model and a input batch yielded from the data loader. If not specified, it is expected that inputs yielded from dataloader can be passed directly to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats\n",
    "\n",
    "reestimate_bn_stats(qsim.model, re_estimation_dataset, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fold BatchNorm Layers\n",
    "\n",
    "So far, we have improved our quantization simulation model through QAT and batchnorm re-estimation. The next step would be to actually take this model to target. But first, we should fold the batchnorm layers for our model to run on target devices more efficiently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale\n",
    "fold_all_batch_norms_to_scale(qsim)"
   ]
  },
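  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, folding a batchnorm layer into the preceding convolution amounts to rescaling that layer's weight and bias for each output channel. The sketch below shows the arithmetic for a single channel with scalar values. The function names are hypothetical, and the sketch does not reproduce AIMET's `fold_all_batch_norms_to_scale`, which operates on the quantization simulation model as a whole.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def conv_then_bn(x, weight, bias, gamma, beta, mean, var, eps=1e-3):\n",
    "    # Reference path: linear op followed by batchnorm in inference mode\n",
    "    return gamma * (weight * x + bias - mean) / math.sqrt(var + eps) + beta\n",
    "\n",
    "\n",
    "def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-3):\n",
    "    # Absorb the batchnorm scale and shift into the weight and bias\n",
    "    factor = gamma / math.sqrt(var + eps)\n",
    "    return weight * factor, (bias - mean) * factor + beta\n",
    "\n",
    "\n",
    "# Toy single-channel parameters\n",
    "params = dict(weight=0.5, bias=0.1, gamma=2.0, beta=-1.0, mean=0.3, var=4.0)\n",
    "folded_w, folded_b = fold_bn(**params)\n",
    "```\n",
    "\n",
    "After folding, `folded_w * x + folded_b` gives the same output as `conv_then_bn(x, **params)`, so the separate batchnorm computation is no longer needed at inference time."
   ]
  },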
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 5. Export Model\n",
    "As the final step, we will export the model to run it on actual target devices. AIMET QuantizationSimModel provides an export API for this purpose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.makedirs('./output/', exist_ok=True)\n",
    "qsim.export(path='./output/', filename_prefix='mnist_after_bn_re_estimation_qat_range_learning')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "Hope this notebook was useful for you to understand how to use batchnorm re-estimation feature of AIMET.\n",
    "\n",
    "Few additional resources\n",
    "- Refer to the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) to know more details of the APIs and optional parameters.\n",
    "- Refer to the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/tensorflow/quantization/keras) to understand how to use AIMET post-training quantization techniques and QAT methods."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}